console.log('Hello TensorFlow');

const classNames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
var canvas, ctx, saveButton, clearButton, rawImage;
var pos = {x: 0, y: 0};
var model;

async function showExamples(data) {
  // Create a container in the visor
  // 用tfvis创建一个tab
  const surface =
    tfvis.visor().surface({name: '训练数据示例', tab: '训练数据'});

  // Get the examples
  // 获取10个要训练的数据
  const examples = data.nextTestBatch(10);
  const numExamples = examples.xs.shape[0];

  // Create a canvas element to render each example
  // 创建canvas将每个数据的图像绘制出来
  for (let i = 0; i < numExamples; i++) {
    const imageTensor = tf.tidy(() => {
      // Reshape the image to 28x28 px
      // 分割图集，将图像重塑为28*28大小的图片
      return examples.xs
        .slice([i, 0], [1, examples.xs.shape[1]])
        .reshape([28, 28, 1]);
    });

    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px;';
    // 将图片画到画布上
    await tf.browser.toPixels(imageTensor, canvas);
    // 将画布添加到visor中
    surface.drawArea.appendChild(canvas);
    // 销毁图片，释放内存
    imageTensor.dispose();
  }
}

function getModel() {
  // 创建一个线性堆叠模型
  const model = tf.sequential();
  // 定义图片属性 28*28
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  // 深度为 1，这是因为我们的图片只有1个颜色
  const IMAGE_CHANNELS = 1;

  // In the first layer of our convolutional neural network we have
  // to specify the input shape. Then we specify some parameters for
  // the convolution operation that takes place in this layer.
  // 给模型添加卷积层
  model.add(tf.layers.conv2d({
    // 这个数据的形状将回流入模型的第一层
    inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
    // 划过卷积层过滤窗口的数量将会被应用到输入数据中去。
    // 这里，我们设置了 kernalSize 的值为5，也就是指定了一个5 x 5的卷积窗口。
    kernelSize: 5,
    // 这个 kernelSize 的过滤窗口的数量将会被应用到输入数据中，我们这里将8个过滤器应用到数据中
    filters: 8,
    // 即滑动窗口每一步的步长。比如每当过滤器移动过图片时将会由多少像素的变化。
    // 这里，我们指定其步长为1，这意味着每一步都是1像素的移动。
    strides: 1,
    // 这个 activation 函数将会在卷积完成之后被应用到数据上。
    // 在这个例子中，我们应用了 relu 函数，这个函数在机器学习中是一个非常常见的激活函数。
    activation: 'relu',
    // 这个方法对于训练动态的模型是非常重要的，
    // 他被用于任意地初始化模型的 weights。
    // 我们这里将不会深入细节来讲，
    // 但是 VarianceScaling （即这里用的）真的是一个初始化非常好的选择。
    kernelInitializer: 'varianceScaling'
  }));

  // The MaxPooling layer acts as a sort of downsampling using max values
  // in a region instead of averaging.
  // 给模型添加池化层（pooling layer）
  // 这一层将会通过在每个滑动窗口中计算最大值来降频取样得到结果。
  // 注意：因为 poolSize 和 strides 都是2x2，
  // 所以池化层空口将会完全不会重叠。这也就意味着池化层将会把激活的大小从上一层减少一半。
  model.add(tf.layers.maxPooling2d({
    // 这个滑动池窗口的数量将会被应用到输入的数据中。
    // 这里我们设置 poolSize为[2, 2]，
    // 所以这就意味着池化层将会对输入数据应用2x2的窗口。
    poolSize: [2, 2],
    // 这个池化层的步长大小。
    // 比如，当每次挪开输入数据时窗口需要移动多少像素。
    // 这里我们指定 strides为[2, 2]，
    // 这就意味着过滤器将会以在水平方向和竖直方向上同时移动2个像素的方式来划过图片。
    strides: [2, 2]
  }));

  // Repeat another conv2d + maxPooling stack.
  // Note that we have more filters in the convolution.
  // 添加第二层卷积层
  // 重复使用层结构是神经网络中的常见模式。我们添加第二个卷积层到模型
  // 我们没有指定 inputShape，因为它可以从前一层的输出形状中推断出来
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    // 将滤波器数量从8增加到16。
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  // 添加第二层池化层
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  // 接下来，我们添加一个 flatten 层，将前一层的输出平铺到一个向量中：
  model.add(tf.layers.flatten());

  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
  // 输出数量
  const NUM_OUTPUT_CLASSES = 10;
  // 最后，让我们添加一个 dense 层（也称为全连接层），
  // 它将执行最终的分类。
  // 在 dense 层前先对卷积+池化层的输出执行 flatten 也是神经网络中的另一种常见模式：
  model.add(tf.layers.dense({
    // 激活输出的数量。由于这是最后一层，
    // 我们正在做10个类别的分类任务（数字0-9），因此我们在这里使用10个 units。
    units: NUM_OUTPUT_CLASSES,
    // 我们将对 dense 层使用与卷积层相同的 VarianceScaling 初始化策略。
    kernelInitializer: 'varianceScaling',
    // 分类任务的最后一层的激活函数通常是 softmax。
    // Softmax 将我们的10维输出向量归一化为概率分布，使得我们10个类中的每个都有一个概率值。
    activation: 'softmax'
  }));

  // Choose an optimizer, loss function and accuracy metric,
  // then compile and return the model
  const optimizer = tf.train.adam();
  // 为了编译模型，我们传入一个由优化器，损失函数和一系列评估指标（这里只是'精度'）组成的配置对象
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}

async function train(model, data) {
  // 为迭代回调设置以下指标 'loss', 'val_loss', 'acc', 'val_acc' 显示到页面tab上
  // loss：训练集损失值
  // val_loss:测试集损失值
  // accuracy:训练集准确率
  // val_accruacy:测试集准确率
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  // 创建一个容器，用于显示训练过程
  const container = {
    name: '模型训练', tab: '模型', styles: {height: '1000px'}
  };
  // 使用tfvis.show.fitCallbacks()设置回调。
  // 使用上面定义的容器和度量作为参数。
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
  // 设置训练批次数量，以及大小
  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  // 获取训练批次并调整其大小
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });
  // 获取测试批次并调整其大小。
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });
  //开始训练
  return model.fit(trainXs, trainYs, {
    // 每个训练 batch 中包含多少个图像。之前我们在这里设置的BATCH_SIZE是 512
    batchSize: BATCH_SIZE,
    // 我们的评估度量（准确度）将在此数据集上计算（用测试数据来测试准确度）
    validationData: [testXs, testYs],
    // 迭代次数
    epochs: 10,
    // 是否再每轮迭代之前混洗数据
    shuffle: true,
    // 函数回调
    callbacks: fitCallbacks
  });
}

function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const testData = data.nextTestBatch(testDataSize);
  const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
  const labels = testData.labels.argMax(-1);
  const preds = model.predict(testxs).argMax(-1);

  testxs.dispose();
  return [preds, labels];
}

async function showAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
  const container = {name: '准确度', tab: '训练结果'};
  tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

  labels.dispose();
}

async function showConfusion(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
  const container = {name: '混淆矩阵', tab: 'Evaluation'};
  tfvis.render.confusionMatrix(container, {values: confusionMatrix, tickLabels: classNames});

  labels.dispose();
}

function draw(e) {
  if (e.buttons != 1) return;
  ctx.beginPath();
  ctx.lineWidth = 24;
  ctx.lineCap = 'round';
  ctx.strokeStyle = 'white';
  ctx.moveTo(pos.x, pos.y);
  setPosition(e);
  ctx.lineTo(pos.x, pos.y);
  ctx.stroke();
  rawImage.src = canvas.toDataURL('image/png');
}

function setPosition(e) {
  pos.x = e.clientX - 100;
  pos.y = e.clientY - 100;
}

function save() {
  var raw = tf.browser.fromPixels(rawImage, 1);
  var resized = tf.image.resizeBilinear(raw, [28, 28]);
  var tensor = resized.expandDims(0);

  var prediction = model.predict(tensor);
  var pIndex = tf.argMax(prediction, 1).dataSync();
  console.log('写下了：', classNames[pIndex])
}

function erase() {
  ctx.fillStyle = "black";
  ctx.fillRect(0, 0, 280, 280);
}

function init() {
  canvas = document.getElementById('canvas');
  rawImage = document.getElementById('canvasimg');
  ctx = canvas.getContext("2d");
  ctx.fillStyle = "black";
  ctx.fillRect(0, 0, 280, 280);
  canvas.addEventListener("mousemove", draw);
  canvas.addEventListener("mousedown", setPosition);
  canvas.addEventListener("mouseenter", setPosition);
  saveButton = document.getElementById('sb');
  saveButton.addEventListener("click", save);
  clearButton = document.getElementById('cb');
  clearButton.addEventListener("click", erase);
}

async function run() {
  // 创建MnistData对象
  const data = new MnistData();
  // 等待精灵图数据加载完成
  await data.load();
  // 渲染部分训练集中的图片到页面上
  await showExamples(data);
  // 创建训练模型
  model = getModel();
  // 将训练过程中每个迭代的变化以表格的形式展示到页面上
  tfvis.show.modelSummary({name: '模型结构', tab: '模型'}, model);
  // 开始训练
  await train(model, data);
  // 显示训练过程图
  await showAccuracy(model, data);
  await showConfusion(model, data);
  // 初始化手写工具
  init();
  alert("Training is done, try classifying your drawings!");
}

// 文档就绪函数
document.addEventListener('DOMContentLoaded', run);

